"""This module provides functionality to estimate confidence intervals via bootstrapping.

Functions in this module should be considered experimental, meaning there might be breaking API changes in the future.
"""

from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
from joblib import Parallel, delayed
from scipy.optimize import minimize
from tqdm import tqdm

from dowhy.gcm import config
from dowhy.gcm.util.general import set_random_seed, shape_into_2d


def estimate_geometric_median(X: np.ndarray) -> np.ndarray:
    """Numerically estimates the geometric median of the rows of X.

    The geometric median is the point that minimizes the sum of Euclidean distances to all rows of X. It is searched
    for with scipy.optimize.minimize (using its default settings), starting from the arithmetic mean of the rows.

    :param X: Two-dimensional array where each row is one sample and each column one dimension.
    :return: The solution vector reported by the optimizer, with one entry per column of X.
    """
    num_samples = X.shape[0]
    # The arithmetic mean of the rows serves as the initial guess for the optimizer.
    starting_point = np.sum(X, axis=0) / num_samples

    def total_distance_to_samples(candidate: np.ndarray) -> np.ndarray:
        squared_offsets = (candidate - X) ** 2
        distances = np.sqrt(np.sum(squared_offsets, axis=1))
        return np.sum(distances)

    optimization_result = minimize(total_distance_to_samples, starting_point)
    return optimization_result.x


def confidence_intervals(
    estimation_func: Union[Callable[[], np.ndarray], Callable[[], Dict[Any, float]]],
    confidence_level: float = 0.95,
    num_bootstrap_resamples: int = 20,
    bootstrap_results_summary_func: Callable[[np.ndarray], np.ndarray] = estimate_geometric_median,
    n_jobs: int = 1,
) -> Tuple[Union[np.ndarray, Dict[Any, np.ndarray]], Union[np.ndarray, Dict[Any, np.ndarray]]]:
    """
    Estimates confidence intervals based on the outputs generated by calling the given estimation_func. Since one result
    for each repetition is produced, all results can be summarized by the method defined in
    bootstrap_results_summary_func.
    For instance, bootstrap_results_summary_func = lambda x: numpy.mean(x, axis=0) to get the mean over all runs.
    By default, the geometric median is returned.

    Currently, the confidence intervals are empirically estimated based on the n-th estimated quantiles (without bias
    correction) of the results, where the quantiles are determined by the given confidence_level. Concretely, the
    bounds are the ((1 - confidence_level) * 100)-th and the (confidence_level * 100)-th percentile of the results,
    e.g. the 5th and 95th percentile for the default confidence_level of 0.95.

    Before each call of estimation_func, set_random_seed is called with a seed that was drawn via numpy.random.

    NOTE: The outputs of estimation_func are assumed to be pairwise independent. For multidimensional outputs of
    estimation_func, this could be violated and should be kept in mind. For instance, when evaluating the outcome of
    interventions in a graph like X -> Y -> Z, the confidence intervals are estimated independently for X, Y and Z
    although they have a strong dependency. If estimation_func returns one dimensional results, as for instance when
    estimating the direct arrow strength, then there should be no problem.

    **Example usage with numpy array output:**

        >>> def estimation_func() -> np.ndarray:
        >>>     return direct_arrow_strength_of_model(causal_model, parent_data)
        >>>
        >>> arrow_strengths, intervals = confidence_intervals(estimation_func)

    **Example usage with dictionary output:**

        >>> def estimation_func() -> Dict[Any, float]:
        >>>     return distribution_change(
        >>>             causal_dag, original_observations, outlier_observations, 'X3')
        >>>
        >>> mean_contributions, intervals = confidence_intervals(estimation_func)

    More details about the estimation of confidence intervals via bootstrapping can be found `here <https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf>`_.

    :param estimation_func: Function that generates a non-deterministic output for which the confidence interval(s) are
           estimated. It returns either a numpy array or a dictionary; the type of the first result decides how all
           results are processed.
    :param confidence_level: Confidence level of the interval. Must be a value in [0, 1].
    :param num_bootstrap_resamples: Number of samples generated by estimation_func, i.e. number of times
           estimation_func is called. The higher the number, the more accurate the results and intervals,
           but the slower the runtime.
    :param bootstrap_results_summary_func: Function that takes a numpy array with all results as an input (one row per
           repetition) and returns a single (potentially multidimensional) value/vector. For instance, the mean or
           median over all results.
    :param n_jobs: Number of parallel jobs. Each repetition can be estimated in parallel.
           However, since many other functions of the library are already running in parallel (
           such as distribution change), this is set to 1 by default. Only if it is certain that
           the estimation_func is not running in parallel internally (e.g. when performing
           interventions), this should be set to a different value.
    :return: A tuple (summarized result over all repetitions based on
             bootstrap_results_summary_func, confidence interval for each dimension/variable). For dictionary
             outputs, both entries are dictionaries with the same keys as the results. For array outputs, the
             intervals are returned as an array with one (lower, upper) row per dimension.
    :raises ValueError: If num_bootstrap_resamples is smaller than 1 or confidence_level is not in [0, 1].
    """
    if num_bootstrap_resamples < 1:
        raise ValueError("Number of repetitions should be greater than 0, but got %d" % num_bootstrap_resamples)

    # Fail fast: an invalid level would otherwise only be rejected by the percentile computation after all
    # (potentially expensive) bootstrap repetitions have already been executed.
    if not 0 <= confidence_level <= 1:
        raise ValueError("The confidence level should be a value in [0, 1], but got %s" % confidence_level)

    def estimation_func_with_random_seed(random_seed: int) -> Union[np.ndarray, Dict[Any, float]]:
        # Seeding inside the job gives every repetition its own seed, also when it runs in a separate worker.
        set_random_seed(random_seed)
        return estimation_func()

    random_seeds = np.random.randint(np.iinfo(np.int32).max, size=num_bootstrap_resamples)
    all_results = Parallel(n_jobs=n_jobs)(
        delayed(estimation_func_with_random_seed)(random_seed)
        for random_seed in tqdm(
            random_seeds,
            position=0,
            leave=True,
            disable=not config.show_progress_bars,
            desc="Estimating boostrap interval...",
        )
    )

    if isinstance(all_results[0], dict):
        # Collect the values of all repetitions per key.
        results_per_key: Dict[Any, List[float]] = {}
        for result in all_results:
            for key, value in result.items():
                results_per_key.setdefault(key, []).append(value)

        # One column per key (in insertion order), one row per repetition.
        summary_result = bootstrap_results_summary_func(np.column_stack(list(results_per_key.values())))

        summary_per_key = {key: summary_result[i] for i, key in enumerate(results_per_key)}
        intervals_per_key = {
            key: _estimate_percentile_bounds(np.array(values).squeeze(), confidence_level)
            for key, values in results_per_key.items()
        }

        return summary_per_key, intervals_per_key

    results_matrix: np.ndarray = shape_into_2d(np.array(all_results))

    # The bounds are estimated independently for each column of the result matrix.
    return bootstrap_results_summary_func(results_matrix), np.array(
        [_estimate_percentile_bounds(results_matrix[:, i], confidence_level) for i in range(results_matrix.shape[1])]
    )


def _estimate_percentile_bounds(X: np.ndarray, quantile: float) -> np.ndarray:
    """Returns the empirical lower and upper percentile bound of the given samples.

    The lower bound is the ((1 - quantile) * 100)-th percentile and the upper bound is the (quantile * 100)-th
    percentile of X.

    NOTE(review): For quantile = 0.95 these are the 5th and 95th percentile, i.e. the bounds enclose the central 90% of
    the samples. Confirm whether a symmetric split of (1 - quantile) over both tails is intended instead.

    :param X: Samples for which the bounds are estimated. Must not have more than one dimension.
    :param quantile: Value in [0, 1] defining the percentiles used as bounds.
    :return: A numpy array of the form [lower bound, upper bound].
    :raises ValueError: If X has more than one dimension.
    """
    is_multidimensional = X.ndim > 1
    if is_multidimensional:
        raise ValueError("Estimate bounds currently only supports one dimensional inputs!")

    lower_percentile = (1 - quantile) * 100
    upper_percentile = quantile * 100

    lower_bound = np.percentile(X, lower_percentile)
    upper_bound = np.percentile(X, upper_percentile)

    return np.array([lower_bound, upper_bound])
